# -*- coding:utf8 -*-
"""
Created on 2019/9/26 17:18

@author: minc
"""

from flask import Flask, render_template
from flask.json import jsonify

from pyecharts import options as opts
from pyecharts.charts import Bar
from pyecharts.charts import Line

from random import randrange
import datetime,os
from flask import request
import time
import json
from sklearn import metrics
import numpy as np

import os

from utils import get_option
# Absolute path of the directory holding this module, and of its "static"
# sub-directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, 'static')

# The WSGI application object every route below is registered on.
app = Flask(__name__)

# Input files are located through the environment. Both variables are
# mandatory: a missing one raises KeyError as soon as this module is imported.
data_file, label_file = (
    os.environ[variable] for variable in ('DATA_FILE', 'LABEL_FILE')
)

# Index page, rendered from the custom template file index.html
@app.route("/")
def show_demo01():
    """Render the overview page: one table row per record plus a metrics report.

    Reads the JSON records from ``data_file`` and the label names (one per
    line) from ``label_file``. Each record is flattened to a row with keys
    ``id``, ``text``, ``predict`` and ``label``. ``text`` is every string
    field concatenated; ``predict`` and ``label`` are '#'-joined. The gold
    labels come from the record's ``label`` field when present, otherwise
    from ``labels``.

    A multi-label indicator matrix is built from the gold labels and the
    predictions and passed to ``sklearn.metrics.classification_report``.
    The textual report is parsed into dicts with keys ``name``,
    ``precision``, ``recall``, ``f1_score`` and ``support``.

    Returns:
        The rendered ``index.html`` template, given ``data`` (the rows) and
        ``report`` (the parsed report rows; empty when no record carries a
        gold label).

    Raises:
        KeyError: if a record has no ``predict`` field.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    # Decode the label file as UTF-8 like the data file, so non-ASCII label
    # names do not depend on the platform's default encoding.
    with open(label_file, 'r', encoding='utf-8') as f:
        known_labels = f.read().strip().split('\n')

    # Keep only labels listed in the label file, in label-file order.
    # Records without a 'labels' field are left untouched so the 'label'
    # fallback below can still apply to them.
    for record in origin_data:
        if 'labels' in record:
            record['labels'] = [name for name in known_labels if name in record['labels']]

    data = []
    gold_per_row = []
    predicted_per_row = []
    for index, record in enumerate(origin_data):
        predicted = record['predict']
        gold = record['label'] if 'label' in record else record.get('labels', [])
        data.append({
            'id': record.get('id', index),
            'text': ''.join(value for value in record.values() if isinstance(value, str)),
            'predict': '#'.join(predicted),
            'label': '#'.join(gold),
        })
        # Score from the original sequences rather than re-splitting the
        # '#'-joined strings: a record without labels then contributes no
        # class (instead of an empty-string class), and a name containing
        # '#' stays intact.
        gold_per_row.append([name for name in gold if name != ''])
        predicted_per_row.append([name for name in predicted if name != ''])

    # Sorted so the report rows come out in the same order on every run
    # (plain set iteration order varies between processes).
    class_names = sorted({name for gold in gold_per_row for name in gold})
    column_of = {name: column for column, name in enumerate(class_names)}

    y = np.zeros((len(data), len(class_names)))
    y_hat = np.zeros((len(data), len(class_names)))
    for row, (gold, predicted) in enumerate(zip(gold_per_row, predicted_per_row)):
        for name in gold:
            y[row][column_of[name]] = 1
        for name in predicted:
            # Predictions outside the gold label set have no column; ignore them.
            column = column_of.get(name)
            if column is not None:
                y_hat[row][column] = 1

    report = []
    if class_names:
        raw_report = metrics.classification_report(y, y_hat, digits=4, target_names=class_names)
        # The first line is the column header; every data line ends with
        # four whitespace-separated values, so split from the right to keep
        # names that contain spaces (e.g. "micro avg") in one piece.
        for line in raw_report.split('\n')[1:]:
            cells = line.strip().rsplit(None, 4)
            if len(cells) < 5:
                continue
            name, precision, recall, f1_score, support = cells
            report.append({
                'name': name,
                'precision': precision,
                'recall': recall,
                'f1_score': f1_score,
                'support': support,
            })

    return render_template("index.html", data=data, report=report)

@app.route("/item",methods=['GET'])
def get_emr_info():
    """Render the detail page for the record selected by the ``id`` query argument.

    A record matches when its ``id`` field equals the requested id, either
    as an integer or as text. A record without an ``id`` field is matched
    by its position in the data file. The first matching record is shown.

    The template receives ``emr``, a list of ``{"key": ..., "value": ...}``
    dicts (one per field of the record), and ``graph_option``, the result
    of ``get_option(nodes, triples)``. ``triples`` are the 3-element
    windows taken at every second position of each entry in the record's
    ``paths``. Records without ``nodes`` get ``get_option([], [])``.

    Returns:
        The rendered ``item.html`` template; a 400 response when ``id`` is
        missing; a 404 response when no record matches.
    """
    emr_id = request.args.get('id')
    if emr_id is None:
        return "missing required query parameter 'id'", 400

    wanted_text = emr_id.strip()
    try:
        wanted_number = int(wanted_text)
    except ValueError:
        # Non-numeric ids can still match records whose id is stored as text.
        wanted_number = None

    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    # Decode the label file as UTF-8 like the data file, so non-ASCII label
    # names do not depend on the platform's default encoding.
    with open(label_file, 'r', encoding='utf-8') as f:
        known_labels = f.read().strip().split('\n')

    # Keep only labels listed in the label file, in label-file order.
    for record in origin_data:
        if 'labels' in record:
            record['labels'] = [name for name in known_labels if name in record['labels']]

    for index, record in enumerate(origin_data):
        record_id = record.get('id', index)
        matches_number = wanted_number is not None and record_id == wanted_number
        if not matches_number and str(record_id) != wanted_text:
            continue

        emr = [{"key": key, "value": value} for key, value in record.items()]

        nodes = []
        triples = []
        if 'nodes' in record:
            nodes = record['nodes']
            for path in record.get('paths', []):
                # Overlapping windows of three, advancing two elements at a
                # time; a trailing window shorter than three is dropped.
                for start in range(0, len(path) - 2, 2):
                    triples.append(path[start:start + 3])

        graph_option = get_option(nodes, triples)
        return render_template("item.html", emr=emr, graph_option=graph_option)

    # Returning None from a view is an error in Flask, so report the miss explicitly.
    return "record not found", 404

@app.route('/search',methods=['GET'])
def search():
    """Return, as JSON, the records that satisfy every filter in ``query``.

    ``query`` is a '##'-separated list of ``key:value`` clauses. A clause
    whose key contains ``id``, ``predict`` or ``label`` sets the matching
    filter. Clauses without a ':' or with an empty value are ignored, and a
    missing ``query`` argument means "no filters".

    Filters:
        id: the record id (or its position in the file when it has no
            ``id`` field), compared as text.
        predict: substring of the '#'-joined predictions.
        label: substring of the '#'-joined gold labels, taken from ``label``
            when present, otherwise from ``labels`` (the same choice the
            index view makes).

    Returns:
        A JSON response ``{"msg": ..., "code": 200, "data": [...]}`` where
        each entry of ``data`` has keys ``id``, ``text``, ``predict`` and
        ``label``.

    Raises:
        KeyError: if a record has no ``predict`` field.
    """
    raw_query = request.args.get('query') or ''
    query_id = None
    query_predict = None
    query_label = None

    for clause in raw_query.split('##'):
        # Split on the first ':' only, so a value may itself contain colons.
        key, separator, value = clause.partition(':')
        if not separator or value == '':
            continue
        # Inspect the key only: testing the whole clause would let a value
        # such as "acid" switch on the id filter.
        if 'id' in key:
            query_id = value
        if 'predict' in key:
            query_predict = value
        if 'label' in key:
            query_label = value

    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    # Decode the label file as UTF-8 like the data file, so non-ASCII label
    # names do not depend on the platform's default encoding.
    with open(label_file, 'r', encoding='utf-8') as f:
        known_labels = f.read().strip().split('\n')

    # Keep only labels listed in the label file, in label-file order.
    for record in origin_data:
        if 'labels' in record:
            record['labels'] = [name for name in known_labels if name in record['labels']]

    data = []
    for index, record in enumerate(origin_data):
        gold = record['label'] if 'label' in record else record.get('labels', [])
        item = {
            'id': record.get('id', index),
            'text': ''.join(value for value in record.values() if isinstance(value, str)),
            'predict': '#'.join(record['predict']),
            'label': '#'.join(gold),
        }

        # Query values are always text while ids are usually integers, so
        # compare their text form.
        if query_id is not None and str(item['id']) != query_id:
            continue
        if query_predict is not None and query_predict not in item['predict']:
            continue
        if query_label is not None and query_label not in item['label']:
            continue
        data.append(item)

    return jsonify({"msg":"提交成功","code":200,"data":data})

if __name__ == "__main__":
    # Started as a script: listen on all interfaces on port 8004, with
    # Flask's debug mode switched off.
    app.run(host='0.0.0.0', port=8004, debug=False)